import os
import random
import xml.etree.ElementTree as ET

import numpy as np

from utils.utils import get_classes

# --------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode用于指定该文件运行时计算的内容
#   annotation_mode为0代表整个标签处理过程，包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
#   annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
#   annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
# --------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0
# -------------------------------------------------------------------#
#   必须要修改，用于生成2007_train.txt、2007_val.txt的目标信息
#   与训练和预测所用的classes_path一致即可
#   如果生成的2007_train.txt里面没有目标信息
#   那么就是因为classes没有设定正确
#   仅在annotation_mode为0和2的时候有效
# -------------------------------------------------------------------#
classes_path = "model_data/voc_classes.txt"
# --------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent用于指定(训练集+验证集)与测试集的比例，默认情况下 (训练集+验证集):测试集 = 9:1
#   train_percent用于指定(训练集+验证集)中训练集与验证集的比例，默认情况下 训练集:验证集 = 9:1
#   仅在annotation_mode为0和1的时候有效
# --------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9
# -------------------------------------------------------#
#   指向VOC数据集所在的文件夹
#   默认指向根目录下的VOC数据集
# -------------------------------------------------------#
VOCdevkit_path = "VOCdevkit"
# 获取当前路径
VOCdevkit_path = os.path.abspath(VOCdevkit_path)

VOCdevkit_sets = [("2007", "train"), ("2007", "val")]
classes, _ = get_classes(classes_path)

# -------------------------------------------------------#
#   统计目标数量
# -------------------------------------------------------#
photo_nums = np.zeros(len(VOCdevkit_sets))
nums = np.zeros(len(classes))


def convert_annotation(year, image_id, list_file):
    # in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    in_file = open(
        os.path.join(VOCdevkit_path, "merge/Annotations/%s.xml" % (image_id)).replace(
            "\\", "/"
        ),
        encoding="utf-8",
    )
    tree = ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter("object"):
        difficult = 0
        if obj.find("difficult") != None:
            difficult = obj.find("difficult").text
        cls = obj.find("name").text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find("bndbox")
        b = (
            int(float(xmlbox.find("xmin").text)),
            int(float(xmlbox.find("ymin").text)),
            int(float(xmlbox.find("xmax").text)),
            int(float(xmlbox.find("ymax").text)),
        )
        list_file.write(" " + ",".join([str(a) for a in b]) + "," + str(cls_id))

        nums[classes.index(cls)] = nums[classes.index(cls)] + 1


if __name__ == "__main__":
    random.seed(0)
    if " " in os.path.abspath(VOCdevkit_path):
        raise ValueError(
            "数据集存放的文件夹路径与图片名称中不可以存在空格，否则会影响正常的模型训练，请注意修改。"
        )

    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath = os.path.join(VOCdevkit_path, "merge/Annotations")
        saveBasePath = os.path.join(VOCdevkit_path, "merge/ImageSets/Main")
        temp_xml = os.listdir(xmlfilepath)
        total_xml = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num = len(total_xml)
        list = range(num)
        tv = int(num * trainval_percent)
        tr = int(tv * train_percent)
        trainval = random.sample(list, tv)
        train = random.sample(trainval, tr)

        print("train and val size", tv)
        print("train size", tr)
        ftrainval = open(os.path.join(saveBasePath, "trainval.txt"), "w")
        ftest = open(os.path.join(saveBasePath, "test.txt"), "w")
        ftrain = open(os.path.join(saveBasePath, "train.txt"), "w")
        fval = open(os.path.join(saveBasePath, "val.txt"), "w")

        for i in list:
            name = total_xml[i][:-4] + "\n"
            if i in trainval:
                ftrainval.write(name)
                if i in train:
                    ftrain.write(name)
                else:
                    fval.write(name)
            else:
                ftest.write(name)

        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        type_index = 0
        for year, image_set in VOCdevkit_sets:
            # image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            image_id_file = open(
                os.path.join(
                    VOCdevkit_path, "merge/ImageSets/Main/%s.txt" % (image_set)
                ).replace("\\", "/"),
                encoding="utf-8",
            )
            # 根据回车回去image_id_file id
            image_ids = image_id_file.read().strip().split("\n")
            list_file = open("%s_%s.txt" % (year, image_set), "w", encoding="utf-8")
            print("image_ids", image_ids)
            for image_id in image_ids:
                list_file.write(
                    "%s/merge/JPEGImages/%s.jpg"
                    % (os.path.abspath(VOCdevkit_path), image_id)
                )

                convert_annotation(year, image_id, list_file)
                list_file.write("\n")
            photo_nums[type_index] = len(image_ids)
            type_index += 1
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

        def printTable(List1, List2):
            for i in range(len(List1[0])):
                print("|", end=" ")
                for j in range(len(List1)):
                    print(List1[j][i].rjust(int(List2[j])), end=" ")
                    print("|", end=" ")
                print()

        str_nums = [str(int(x)) for x in nums]
        tableData = [classes, str_nums]
        colWidths = [0] * len(tableData)
        len1 = 0
        for i in range(len(tableData)):
            for j in range(len(tableData[i])):
                if len(tableData[i][j]) > colWidths[i]:
                    colWidths[i] = len(tableData[i][j])
        printTable(tableData, colWidths)

        if photo_nums[0] <= 500:
            print(
                "训练集数量小于500，属于较小的数据量，请注意设置较大的训练世代（Epoch）以满足足够的梯度下降次数（Step）。"
            )

        if np.sum(nums) == 0:
            print(
                "在数据集中并未获得任何目标，请注意修改classes_path对应自己的数据集，并且保证标签名字正确，否则训练将会没有任何效果！"
            )
            print(
                "在数据集中并未获得任何目标，请注意修改classes_path对应自己的数据集，并且保证标签名字正确，否则训练将会没有任何效果！"
            )
            print(
                "在数据集中并未获得任何目标，请注意修改classes_path对应自己的数据集，并且保证标签名字正确，否则训练将会没有任何效果！"
            )
            print("（重要的事情说三遍）。")
